import os.path
import matplotlib.pyplot as plt


def _count_entries(directory):
    """Return the number of entries in *directory*, or 0 if it is not a directory."""
    if not os.path.isdir(directory):
        return 0
    return len(os.listdir(directory))


def dataset_visual(dataset_path="dataset"):
    """Show a bar chart of per-class entry counts for the train and test splits.

    Class labels are the names of the sub-directories of
    ``<dataset_path>/train``. For every label the number of entries in
    ``train/<label>`` and ``test/<label>`` is counted; both series are drawn
    on the same axes (test bars on top of train bars) and the figure is
    displayed with ``plt.show()``.

    Args:
        dataset_path: Root directory that contains the ``train`` and ``test``
            sub-directories, each holding one sub-directory per class.

    Raises:
        FileNotFoundError: If ``<dataset_path>/train`` does not exist.
    """
    train_path = os.path.join(dataset_path, 'train')
    test_path = os.path.join(dataset_path, 'test')

    # Only sub-directories are classes; stray files (e.g. Thumbs.db) would
    # otherwise make the per-class listing fail. Sorting keeps the bar order
    # stable, since os.listdir returns entries in arbitrary order.
    labels = sorted(
        entry for entry in os.listdir(train_path)
        if os.path.isdir(os.path.join(train_path, entry))
    )

    train_count = [
        _count_entries(os.path.join(train_path, label)) for label in labels
    ]
    # A class missing from the test split is plotted as zero instead of
    # aborting the whole chart.
    test_count = [
        _count_entries(os.path.join(test_path, label)) for label in labels
    ]

    x = range(len(labels))
    plt.bar(x, train_count, tick_label=labels, label='train')
    plt.bar(x, test_count, label='test')
    plt.legend()
    plt.show()


# Run the visualisation only when this file is executed as a script, so that
# importing the module does not open a plot window as a side effect.
if __name__ == "__main__":
    dataset_visual(r"D:\WorkSpace\Animals10\dataset")